# Base image: the JAX container published under ghcr.io/nvidia, selected by
# its dated tag (jax-2024-03-08). Only the tag is given; no digest is pinned.
FROM ghcr.io/nvidia/jax:jax-2024-03-08

# OS-level dependency: pip, from the distribution's python3-pip package.
# The apt package index is refreshed, used, and removed inside this single
# layer so that no index data is committed to the image.
RUN apt-get update && \
    apt-get install -y --no-install-recommends python3-pip && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Python dependencies live in their own layer so that editing this list does
# not invalidate (and re-run) the apt layer above. --no-cache-dir keeps pip's
# download cache out of the image.
# torch is installed in a separate pip invocation because it is fetched from
# the PyTorch CPU wheel index instead of the default package index.
# NOTE(review): only orbax is version-pinned; the remaining packages resolve
# to whatever is current at build time — consider pinning for reproducibility.
RUN pip install --no-cache-dir \
        transformers \
        sentencepiece \
        simple_parsing \
        datasets \
        orbax==0.4.8 && \
    pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu